(define-module (decision-tree))

(use-modules
 ;; SRFI-1 for list procedures
 ((srfi srfi-1) #:prefix srfi1:)
 ;; SRFI-8 for `receive` form
 (srfi srfi-8)
 (utils csv-utils)
 (utils display-utils)
 (utils math-utils)
 (utils string-utils)
 (utils list-utils)
 (dataset)
 (data-point)
 (tree)
 (metrics)
 (prediction)
 (split-quality-measure))


;; Name of the CSV file holding the banknote authentication data.  It is a
;; relative path, so it is resolved against the current working directory.
(define FILE-PATH "data_banknote_authentication.csv")

;; For each column we define a column converter, which converts the string read
;; from the CSV file into the appropriate data type for the in-program data set.


(define COLUMN-CONVERTERS
  (list (list string->number)
        (list string->number)
        (list string->number)
        (list string->number)
        (list string->number)))

;; Using the defined column converters, we define the data set.
;; The full data set, read from FILE-PATH with one converter per column taken
;; from COLUMN-CONVERTERS.  Using FILE-PATH (instead of repeating the file name
;; as a literal) keeps the constant and the actual data source in sync.
;; NOTE(review): this reads the CSV at module load time.
(define banking-dataset
  (all-rows FILE-PATH #:converters COLUMN-CONVERTERS))

;; This is an artefact from development. It serves as an example to test things
;; with interactively or in a shorter time than with a whole larger data set.
;; Small hand-written data set left over from development: ten rows of two
;; numeric features followed by a 0/1 label, handy for quick interactive tests
;; without loading the full CSV.  Each row is turned into a vector.
(define dev-dataset
  (map list->vector
       '((2.771244718 1.784783929 0)
         (1.728571309 1.169761413 0)
         (3.678319846 2.81281357 0)
         (3.961043357 2.61995032 0)
         (2.999208922 2.209014212 0)
         (7.497545867 3.162953546 1)
         (9.00220326 3.339047188 1)
         (7.444542326 0.476683375 1)
         (10.12493903 3.234550982 1)
         (6.642287351 3.319983761 1))))


;; =======================
;; DECISION TREE ALGORITHM
;; =======================
(define-public split-data
  (lambda (data index value)
    "Partition DATA by comparing column INDEX of every data point with VALUE.

Returns a two-element list holding the two partitions produced by
dataset-partition for the predicate `column INDEX < VALUE' (presumably the
points below VALUE first and the remaining points second -- confirm against
dataset-partition)."
    (call-with-values
        (lambda ()
          (dataset-partition
           (lambda (point)
             (< (data-point-get-col point index) value))
           data))
      ;; Exactly two values are expected, as with the original `receive'.
      (lambda (below-threshold rest)
        (list below-threshold rest)))))


(define-public select-min-cost-split
  (lambda (split-a split-b)
    "Return whichever of SPLIT-A and SPLIT-B has the strictly smaller
split-cost.  If SPLIT-A is not strictly cheaper, SPLIT-B is returned (so ties
go to SPLIT-B)."
    (let ([cost-a (split-cost split-a)]
          [cost-b (split-cost split-b)])
      (cond
       [(< cost-a cost-b) split-a]
       [else split-b]))))


(define-public get-best-split-for-column
  (lambda* (data
            label-column-index
            column-index
            #:key
            (split-quality-proc gini-index))
    "Find the cheapest split of DATA on the column at COLUMN-INDEX.

Every value occurring in that column is tried as a threshold: DATA is split
with split-data and the cost is computed by SPLIT-QUALITY-PROC (gini-index by
default) from the two subsets and LABEL-COLUMN-INDEX.  Returns the Split record
with the lowest cost; on equal cost the later candidate wins.  For an empty
column, a placeholder split with cost +inf.0 is returned."
    (let ([candidate-split
           ;; Build the Split record obtained by thresholding at SPLIT-VALUE.
           ;; FUTURE TODO: this record may be thrown away immediately if it is
           ;; not better than the best one so far; avoiding its creation would
           ;; require juggling several values and was judged less readable.
           (lambda (split-value)
             (let ([subsets (split-data data column-index split-value)])
               (make-split column-index
                           split-value
                           subsets
                           (split-quality-proc subsets label-column-index))))])
      ;; TODO: Parallelism: the candidate values are independent of each
      ;; other and could be evaluated in parallel instead of sequentially.
      (let scan ([remaining-values (dataset-get-col data column-index)]
                 [best (make-split 0 +inf.0 (list '() '()) +inf.0)])
        (if (dataset-column-empty? remaining-values)
            best
            (scan (dataset-column-rest remaining-values)
                  (select-min-cost-split
                   best
                   (candidate-split
                    (dataset-column-first remaining-values)))))))))


(define-public get-best-split
  (lambda* (data
            feature-column-indices
            label-column-index
            #:key
            (split-quality-proc gini-index))
    "Find the cheapest split of DATA over its feature columns.

FEATURE-COLUMN-INDICES is a list of column indices that may be used for
splitting.  When it is empty (or #f), every column of the data points is a
candidate.  LABEL-COLUMN-INDEX is never used for splitting, even if it is
listed.  SPLIT-QUALITY-PROC computes the cost of a split (gini-index by
default).

Columns are examined in the given order (ascending when all columns are used)
and the cheapest split is returned; on equal cost the later column wins.  If no
column is examined, a placeholder split with cost +inf.0 is returned."
    (let* ([candidate-columns
            (if (or (not feature-column-indices)
                    (null? feature-column-indices))
                ;; No explicit selection: consider every column of a data point.
                (iota (data-point-length (dataset-first data)))
                feature-column-indices)]
           [initial-placeholder-split
            (make-split 0 +inf.0 (list '() '()) +inf.0)])
      ;; TODO: Parallelism: Here we could use multiple cores to calculate the
      ;; best split for different columns in parallel.
      (let iter-columns ([remaining-columns candidate-columns]
                         [best-split-so-far initial-placeholder-split])
        (cond
         [(null? remaining-columns) best-split-so-far]
         [(= (car remaining-columns) label-column-index)
          ;; The label must never be used as a split feature.
          (iter-columns (cdr remaining-columns) best-split-so-far)]
         [else
          ;; Find the best split value within this column and keep it if it
          ;; beats the best split found so far.
          (iter-columns (cdr remaining-columns)
                        (select-min-cost-split
                         best-split-so-far
                         (get-best-split-for-column
                          data
                          label-column-index
                          (car remaining-columns)
                          #:split-quality-proc split-quality-proc)))])))))


(define-public fit
  (lambda* (#:key
            train-data
            (feature-column-indices '())
            label-column-index
            (max-depth 6)
            (min-data-points 12)
            (min-data-points-ratio 0.02)
            (min-impurity-split (expt 10 -7))
            (stop-at-no-impurity-improvement #t)
            (split-quality-proc gini-index))
    ;; Grow a decision tree on TRAIN-DATA.  Each subset is split recursively
    ;; with the cheapest split found by get-best-split (cost computed by
    ;; SPLIT-QUALITY-PROC, gini-index by default) until one of the stopping
    ;; criteria below applies, in which case the subset becomes a leaf.
    ;; Progress is printed with display/displayln.  Returns the tree built
    ;; from make-node / make-leaf-node.
    (define all-data-length (dataset-length train-data))

    #|
    STOP CRITERIA:
    - only one class in a subset (cannot be split any further and does not need to be split)
    - maximum tree depth reached
    - minimum number of data points in a subset
    - minimum ratio of data points in this subset
    - no improvement of the split cost compared to the parent split (optional)
    - impurity of the parent split below MIN-IMPURITY-SPLIT
    |#
    (define all-same-label?
      (lambda (subset)
        ;; FUTURE TODO: Do no longer assume, that the label column is always an
        ;; integer or a number.
        (column-uniform? (dataset-get-col subset label-column-index) =)))

    (define insufficient-data-points-for-split?
      (lambda (subset)
        (let ([number-of-data-points (dataset-length subset)])
          (or (<= number-of-data-points min-data-points)
              (< number-of-data-points 2)))))

    (define max-depth-reached?
      (lambda (depth)
        (>= depth max-depth)))

    ;; Only evaluated after insufficient-data-points-for-split? has ruled out
    ;; subsets with fewer than 2 data points, so ALL-DATA-LENGTH is non-zero
    ;; whenever this runs.
    (define insufficient-data-points-ratio-for-split?
      (lambda (subset)
        (<= (/ (dataset-length subset) all-data-length) min-data-points-ratio)))

    (define no-improvement?
      (lambda (previous-split-impurity split-impurity)
        (and (<= previous-split-impurity split-impurity)
             stop-at-no-impurity-improvement)))

    (define insufficient-impurity?
      (lambda (impurity)
        (< impurity min-impurity-split)))

    ;; Print REASON and the number of data points left in SUBSET, then turn
    ;; SUBSET into a leaf.  Shared by all early-stopping branches below.
    (define stop-with-leaf
      (lambda (reason subset)
        (displayln reason)
        (displayln (string-append "INFO: still got "
                                  (number->string (dataset-length subset))
                                  " data points"))
        (make-leaf-node subset)))

    #|
    Here we do the recursive splitting.
    |#
    (define recursive-split
      (lambda (subset current-depth previous-split-impurity)
        (display "recursive split on depth: ") (displayln current-depth)

        ;; Before splitting further, we check for stopping early conditions.
        ;; TODO: Parallelism: The stopping criteria could be checked in
        ;; parallel, but they are cheap, so the overhead is probably not worth
        ;; it.
        (cond
         [(max-depth-reached? current-depth)
          (stop-with-leaf "STOPPING CONDITION: maximum depth" subset)]
         [(insufficient-data-points-for-split? subset)
          (stop-with-leaf "STOPPING CONDITION: insuficient number of data points"
                          subset)]
         [(insufficient-data-points-ratio-for-split? subset)
          (stop-with-leaf "STOPPING CONDITION: insuficient ratio of data points"
                          subset)]
         [(all-same-label? subset)
          (stop-with-leaf "STOPPING CONDITION: all same label" subset)]
         [else
          (displayln (string-append "INFO: CONTINUING SPLITT: still got "
                                    (number->string (dataset-length subset))
                                    " data points"))
          (let ([best-split
                 (get-best-split subset
                                 feature-column-indices
                                 label-column-index
                                 #:split-quality-proc split-quality-proc)])
            (cond
             [(no-improvement? previous-split-impurity (split-cost best-split))
              (displayln (string-append "STOPPING CONDITION: "
                                        "no improvement in impurity: previously: "
                                        (number->string previous-split-impurity) " "
                                        "now: "
                                        (number->string (split-cost best-split))))
              (make-leaf-node subset)]
             [(insufficient-impurity? previous-split-impurity)
              (displayln "STOPPING CONDITION: not enough impurity for splitting further")
              (make-leaf-node subset)]
             [else
              ;; Here are the recursive calls.  They are not tail calls, but the
              ;; recursion depth is bounded by the tree depth, which is fine.

              ;; TODO: Parallelism: The two recursive calls are independent
              ;; and could run in parallel.

              ;; TODO: Abstraction: We are still using ~car~ / ~cadr~ here.
              (make-node subset
                         (split-feature-index best-split)
                         (split-value best-split)
                         (recursive-split (car (split-subsets best-split))
                                          (+ current-depth 1)
                                          (split-cost best-split))
                         (recursive-split (cadr (split-subsets best-split))
                                          (+ current-depth 1)
                                          (split-cost best-split)))]))])))
    ;; The root sits at depth 1; 1.0 serves as the "impurity" of a fictitious
    ;; parent split.
    (recursive-split train-data 1 1.0)))


(define-public cross-validation-split
  (lambda* (dataset n-folds #:key (random-seed #f))
    "Shuffle DATASET (with RANDOM-SEED, if given) and cut it into folds for
N-FOLDS-fold cross-validation.

The chunk size passed to split-into-chunks-of-size-n is the number of data
points divided by N-FOLDS, rounded up.  NOTE(review): because of the rounding,
fewer than N-FOLDS folds may be produced when the size is not divisible by
N-FOLDS (e.g. 10 data points and 6 folds give a chunk size of 2) -- TODO
confirm against split-into-chunks-of-size-n."
    (let* ([shuffled-dataset (shuffle-dataset dataset #:seed random-seed)]
           [fold-size
            (exact-ceiling (/ (dataset-length shuffled-dataset) n-folds))])
      (split-into-chunks-of-size-n shuffled-dataset fold-size))))


(define-public leave-one-out-k-folds
  (lambda (folds left-out-fold)
    "Return FOLDS without LEFT-OUT-FOLD, preserving the order of the others.

Only the first fold that is equal? to LEFT-OUT-FOLD is removed.  This way,
other folds that happen to have identical contents (possible with duplicate
data points, e.g. one-point folds) stay in the training data.  If no fold
matches, all folds are returned."
    (receive (before matching-and-after)
        (srfi1:break (lambda (fold) (equal? fold left-out-fold))
                     folds)
      (if (null? matching-and-after)
          before
          (append before (cdr matching-and-after))))))



;; evaluates the algorithm using cross validation split with n folds
;; evaluates the algorithm using cross validation split with n folds
(define-public evaluate-algorithm
  (lambda* (#:key
            dataset
            n-folds
            (feature-column-indices '())
            label-column-index
            (max-depth 6)
            (min-data-points 12)
            (min-data-points-ratio 0.02)
            (min-impurity-split (expt 10 -7))
            (stop-at-no-impurity-improvement #t)
            (random-seed #f))
    "Calculate a list of accuracy values, one value for each fold of a
cross-validation split.

DATASET is split with cross-validation-split (using N-FOLDS and RANDOM-SEED).
For each fold, a tree is fitted with fit on all other folds, using the given
hyper-parameters, and evaluated on that fold with accuracy-metric.  The
accuracies are returned in fold order.  FEATURE-COLUMN-INDICES defaults to the
empty list, the same default fit uses."
    (let* ([folds
            (cross-validation-split dataset
                                    n-folds
                                    #:random-seed random-seed)]
           ;; Fit on all folds but FOLD, predict FOLD, return the accuracy.
           [evaluate-fold
            (lambda (fold)
              (let* ([train-set
                      (fold-right append
                                  empty-dataset
                                  (leave-one-out-k-folds folds fold))]
                     ;; Data points handed to the tree for prediction, built
                     ;; with data-point-take-features (presumably without the
                     ;; label column -- confirm).
                     [test-set
                      (map (lambda (data-point)
                             (data-point-take-features data-point
                                                       label-column-index))
                           fold)]
                     [actual-labels (dataset-get-col fold label-column-index)]
                     [tree
                      (fit #:train-data train-set
                           #:feature-column-indices feature-column-indices
                           #:label-column-index label-column-index
                           #:max-depth max-depth
                           #:min-data-points min-data-points
                           #:min-data-points-ratio min-data-points-ratio
                           #:min-impurity-split min-impurity-split
                           #:stop-at-no-impurity-improvement stop-at-no-impurity-improvement)]
                     [predicted-labels
                      (predict-dataset tree test-set label-column-index)])
                (accuracy-metric actual-labels predicted-labels)))])
      ;; FUTURE TODO: Parallelism: the folds are independent and could be
      ;; evaluated in separate jobs.  For now, map-in-order processes them
      ;; strictly one after another, in fold order, so the progress output of
      ;; fit is not interleaved.
      (map-in-order evaluate-fold folds))))
